{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2da16609",
   "metadata": {},
   "source": [
    "# SageMaker JumpStart - Deploy Chronos-2 endpoints to AWS for production use"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f776b6f",
   "metadata": {},
   "source": [
    "In this demo notebook, we will walk through the process of using the **SageMaker Python SDK** to deploy a **Chronos-2** model to a cloud endpoint on AWS. To simplify deployment, we will leverage **SageMaker JumpStart**.\n",
    "\n",
    "### Why Deploy to an Endpoint?\n",
    "So far, we’ve seen how to run models locally, which is useful for experimentation. However, in a production setting, a forecasting model is typically just one component of a larger system. Running models locally doesn’t scale well and lacks the reliability needed for real-world applications.\n",
    "\n",
    "To address this, we deploy models as **endpoints** on AWS. An endpoint acts as a **hosted service**: we send it requests containing time series data, and it returns forecasts in response. This lets other components of a production system request real-time forecasts without running the model themselves."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c77da85",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-warning\">\n",
    "<b>⚠️ Looking for Chronos-Bolt or original Chronos?</b><br>\n",
    "This notebook covers <b>Chronos-2</b>, our latest and recommended model. For documentation on older models (Chronos-Bolt and original Chronos), see the <a href=\"https://github.com/amazon-science/chronos-forecasting/blob/v1.5.3/notebooks/deploy-chronos-bolt-to-amazon-sagemaker.ipynb\"><b>legacy deployment walkthrough</b></a>.\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59912f1f",
   "metadata": {},
   "source": [
    "### Chronos-2 vs. Previous Models\n",
    "\n",
    "**Chronos-2** is a foundation model for time series forecasting that builds on [Chronos](https://arxiv.org/abs/2403.07815) and [Chronos-Bolt](https://aws.amazon.com/blogs/machine-learning/fast-and-accurate-zero-shot-forecasting-with-chronos-bolt-and-autogluon/). It offers broader capabilities and better accuracy, and it handles forecasting scenarios that earlier models do not support, as summarized below.\n",
    "\n",
    "| Capability | Chronos-2 | Chronos-Bolt | Chronos |\n",
    "|------------|-----------|--------------|----------|\n",
    "| Univariate Forecasting | ✅ | ✅ | ✅ |\n",
    "| Cross-learning across items | ✅ | ❌ | ❌ |\n",
    "| Multivariate Forecasting | ✅ | ❌ | ❌ |\n",
    "| Past-only (real/categorical) covariates | ✅ | ❌ | ❌ |\n",
    "| Known future (real/categorical) covariates | ✅ | 🧩 | ❌ |\n",
    "| Max. Context Length | 8192 | 2048 | 512 |\n",
    "| Max. Prediction Length | 1024 | 64 | 64 |\n",
    "\n",
    "🧩 Chronos-Bolt does not natively support future covariates, but it can be combined with external covariate regressors (see the [AutoGluon tutorial](https://auto.gluon.ai/stable/tutorials/timeseries/forecasting-chronos.html#incorporating-the-covariates)). This approach only captures per-timestep effects of the covariates, not effects across time. In contrast, Chronos-2 supports all covariate types natively."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0583a66",
   "metadata": {},
   "source": [
    "## Deploy the model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "690d9093",
   "metadata": {},
   "source": [
    "First, update the SageMaker SDK to access the latest models:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4ed0fb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -U -q sagemaker"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a07054a7",
   "metadata": {},
   "source": [
    "We create a `JumpStartModel` with the necessary configuration based on the model ID. The key parameters are:\n",
    "- `model_id`: Specifies the model to use. We use `pytorch-forecasting-chronos-2` for the [Chronos-2](https://github.com/amazon-science/chronos-forecasting) model.\n",
    "- `instance_type`: Defines the AWS instance for serving the endpoint. Chronos-2 currently requires a **GPU instance** from the `ml.g5`, `ml.g6`, `ml.g6e`, or `ml.g4dn` families with a single GPU. The model does not benefit from multi-GPU instances. **CPU support is coming soon**.\n",
    "\n",
    "   Pricing for each SageMaker instance type for real-time inference is listed on the [SageMaker AI pricing page](https://aws.amazon.com/sagemaker-ai/pricing/).\n",
    "\n",
    "The `JumpStartModel` will automatically set the necessary attributes such as `image_uri` based on the chosen `model_id` and `instance_type`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffbae4f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.jumpstart.model import JumpStartModel\n",
    "\n",
    "model = JumpStartModel(\n",
    "    model_id=\"pytorch-forecasting-chronos-2\",\n",
    "    instance_type=\"ml.g5.2xlarge\",\n",
    "    # You might need to provide the SageMaker execution role to ensure necessary AWS resources are accessible\n",
    "    # role=\"arn:aws:iam::123456789012:role/service-role/AmazonSageMaker-ExecutionRole-XXXXXXXXXXXXXXX\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb864ee1",
   "metadata": {},
   "source": [
    "Next, we deploy the model and create an endpoint. Deployment typically takes a few minutes, as SageMaker provisions the instance, loads the model, and sets up the endpoint for inference.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fd0068b",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor = model.deploy()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4dd66e7",
   "metadata": {},
   "source": [
    "> **Note:** Once the endpoint is deployed, it remains active and incurs charges on your AWS account until it is deleted. The cost depends on factors such as the instance type, the region where the endpoint is hosted, and the duration it remains running. To avoid unnecessary charges, make sure to delete the endpoint when it is no longer needed. For detailed pricing information, refer to the [SageMaker AI pricing page](https://aws.amazon.com/sagemaker-ai/pricing/)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48ce52ef",
   "metadata": {},
   "source": [
    "Alternatively, you can connect to an existing endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a09367fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from sagemaker.predictor import Predictor\n",
    "# from sagemaker.serializers import JSONSerializer\n",
    "# from sagemaker.deserializers import JSONDeserializer\n",
    "\n",
    "# predictor = Predictor(\n",
    "#     \"NAME_OF_EXISTING_ENDPOINT\",\n",
    "#     serializer=JSONSerializer(),\n",
    "#     deserializer=JSONDeserializer(),\n",
    "# )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3def973",
   "metadata": {},
   "source": [
    "## Querying the endpoint"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fbe39f9",
   "metadata": {},
   "source": [
    "We can now invoke the endpoint to make a forecast. We send a **payload** to the endpoint, which includes historical time series values and configuration parameters, such as the prediction length. The endpoint processes this input and returns a **response** containing the forecasted values based on the provided data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1ae7b33e",
   "metadata": {
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "# Define a utility function to print the response in a pretty format\n",
    "from pprint import pformat\n",
    "\n",
    "\n",
    "def nested_round(data, decimals=2):\n",
    "    \"\"\"Round numbers, including nested dicts and lists.\"\"\"\n",
    "    if isinstance(data, float):\n",
    "        return round(data, decimals)\n",
    "    elif isinstance(data, list):\n",
    "        return [nested_round(item, decimals) for item in data]\n",
    "    elif isinstance(data, dict):\n",
    "        return {key: nested_round(value, decimals) for key, value in data.items()}\n",
    "    else:\n",
    "        return data\n",
    "\n",
    "\n",
    "def pretty_format(data):\n",
    "    return pformat(nested_round(data), width=150, sort_dicts=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb1629c7",
   "metadata": {},
   "source": [
    "### Univariate forecasting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "320a9c49",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'predictions': [{'mean': [-0.36, 4.03, 5.31, 2.44, -2.47, -5.09, -4.31, 0.07, 4.41, 5.16],\n",
      "                  '0.1': [-1.69, 2.84, 4.0, 0.97, -3.77, -6.19, -5.34, -1.77, 2.55, 3.61],\n",
      "                  '0.5': [-0.36, 4.03, 5.31, 2.44, -2.47, -5.09, -4.31, 0.07, 4.41, 5.16],\n",
      "                  '0.9': [1.03, 5.0, 6.31, 3.81, -0.85, -3.89, -2.89, 1.84, 5.59, 6.44]}]}\n"
     ]
    }
   ],
   "source": [
    "payload = {\n",
    "    \"inputs\": [\n",
    "        {\"target\": [0.0, 4.0, 5.0, 1.5, -3.0, -5.0, -3.0, 1.5, 5.0, 4.0, 0.0, -4.0, -5.0, -1.5, 3.0, 5.0, 3.0, -1.5, -5.0, -4.0]},\n",
    "    ],\n",
    "    \"parameters\": {\n",
    "        \"prediction_length\": 10\n",
    "    }\n",
    "}\n",
    "response = predictor.predict(payload)\n",
    "print(pretty_format(response))"
   ]
  },
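  {
   "cell_type": "markdown",
   "id": "5c1e7a20",
   "metadata": {},
   "source": [
    "The response is plain JSON. For each input series, `mean` holds the point forecast, and each requested quantile level (e.g. `0.1`) maps to a list of `prediction_length` values. The next cell is a minimal sketch and is not part of the endpoint API. It loads one forecast into a pandas `DataFrame` and checks that the quantile forecasts do not cross. So that it runs without an endpoint, it uses the rounded values printed above. With a live endpoint and a univariate response like this one, you could use `response[\"predictions\"][0]` instead. All variable names are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c1e7a21",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Forecast for the single series above, copied (rounded) from the printed response\n",
    "example_forecast = {\n",
    "    \"mean\": [-0.36, 4.03, 5.31, 2.44, -2.47, -5.09, -4.31, 0.07, 4.41, 5.16],\n",
    "    \"0.1\": [-1.69, 2.84, 4.0, 0.97, -3.77, -6.19, -5.34, -1.77, 2.55, 3.61],\n",
    "    \"0.5\": [-0.36, 4.03, 5.31, 2.44, -2.47, -5.09, -4.31, 0.07, 4.41, 5.16],\n",
    "    \"0.9\": [1.03, 5.0, 6.31, 3.81, -0.85, -3.89, -2.89, 1.84, 5.59, 6.44],\n",
    "}\n",
    "\n",
    "# One row per forecast step, one column per returned statistic\n",
    "example_forecast_df = pd.DataFrame(example_forecast)\n",
    "example_forecast_df.index.name = \"step\"\n",
    "\n",
    "# Quantile columns sorted by level; a lower quantile should never exceed a higher one\n",
    "quantile_cols = sorted((col for col in example_forecast_df.columns if col != \"mean\"), key=float)\n",
    "non_crossing = all(\n",
    "    (example_forecast_df[lo] <= example_forecast_df[hi]).all()\n",
    "    for lo, hi in zip(quantile_cols, quantile_cols[1:])\n",
    ")\n",
    "print(example_forecast_df)\n",
    "print(\"Quantiles non-crossing:\", non_crossing)"
   ]
  },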
  {
   "cell_type": "markdown",
   "id": "0f37d392",
   "metadata": {},
   "source": [
    "A payload may also contain **multiple time series**, each optionally including `start` and `item_id` fields. If `start` is provided, `freq` must also be set in `parameters`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "14c62c74",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'predictions': [{'mean': [1.7, 1.95, 1.66, 1.55, 1.84],\n",
      "                  '0.1': [0.28, 0.32, -0.08, -0.35, -0.18],\n",
      "                  '0.5': [1.7, 1.95, 1.66, 1.55, 1.84],\n",
      "                  '0.9': [3.09, 3.77, 3.62, 3.58, 4.22],\n",
      "                  'item_id': 'product_A',\n",
      "                  'start': '2024-01-01T10:00:00'},\n",
      "                 {'mean': [-1.21, -1.4, -1.27, -1.34, -1.27],\n",
      "                  '0.1': [-4.19, -5.84, -6.38, -7.53, -8.0],\n",
      "                  '0.5': [-1.21, -1.4, -1.27, -1.34, -1.27],\n",
      "                  '0.9': [2.02, 2.92, 3.55, 4.62, 5.66],\n",
      "                  'item_id': 'product_B',\n",
      "                  'start': '2024-02-02T10:00:00'}]}\n"
     ]
    }
   ],
   "source": [
    "payload = {\n",
    "    \"inputs\": [\n",
    "        {\n",
    "            \"target\": [1.0, 2.0, 3.0, 2.0, 0.5, 2.0, 3.0, 2.0, 1.0],\n",
    "            \"item_id\": \"product_A\",\n",
    "            \"start\": \"2024-01-01T01:00:00\",\n",
    "        },\n",
    "        {\n",
    "            \"target\": [5.4, 3.0, 3.0, 2.0, 1.5, 2.0, -1.0],\n",
    "            \"item_id\": \"product_B\",\n",
    "            \"start\": \"2024-02-02T03:00:00\",\n",
    "        },\n",
    "    ],\n",
    "    \"parameters\": {\n",
    "        \"prediction_length\": 5,\n",
    "        \"freq\": \"1h\",\n",
    "        \"quantile_levels\": [0.1, 0.5, 0.9],\n",
    "        \"batch_size\": 2,\n",
    "    },\n",
    "}\n",
    "response = predictor.predict(payload)\n",
    "print(pretty_format(response))"
   ]
  },
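  {
   "cell_type": "markdown",
   "id": "7d2f4b30",
   "metadata": {},
   "source": [
    "The `start` field in the response is the timestamp of the first forecast step. In the output above, product_A starts at 01:00 with 9 hourly observations, and its forecast starts at 10:00. This suggests the first forecast step is one `freq` step after the last observation. The hypothetical helper `forecast_start` in the next cell is a minimal pandas sketch that reproduces both returned `start` values under that assumption. It reflects our reading of the output, not documented endpoint behavior."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d2f4b31",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "\n",
    "def forecast_start(start, num_observations, freq):\n",
    "    \"\"\"Timestamp one `freq` step after the last observation (hypothetical helper).\"\"\"\n",
    "    # num_observations + 1 regular timestamps; the last one is the first forecast step\n",
    "    return pd.date_range(start=start, periods=num_observations + 1, freq=freq)[-1]\n",
    "\n",
    "\n",
    "# Start timestamps and series lengths from the multi-series payload above\n",
    "print(forecast_start(\"2024-01-01T01:00:00\", 9, \"1h\").isoformat())  # product_A\n",
    "print(forecast_start(\"2024-02-02T03:00:00\", 7, \"1h\").isoformat())  # product_B"
   ]
  },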
  {
   "cell_type": "markdown",
   "id": "6ae41cdc",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "source": [
    "### Forecasting with covariates\n",
    "\n",
    "Chronos-2 models also support forecasting with **covariates** (a.k.a. exogenous features or related time series). These can be provided using the `past_covariates` and `future_covariates` keys.\n",
    "\n",
    "**Note:** If you only provide `past_covariates` without matching keys in `future_covariates`, the model will treat them as past-only covariates (features that are only available historically but not in the future).\n",
    "If future values of covariates are available, it is recommended to provide them in `future_covariates` as this typically results in more accurate forecasts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e57f1541",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'predictions': [{'mean': [1.73, 2.09, 1.73], '0.1': [0.36, 0.6, 0.17], '0.5': [1.73, 2.09, 1.73], '0.9': [3.11, 3.8, 3.52]},\n",
      "                 {'mean': [-0.61, -0.41, -1.43], '0.1': [-4.16, -5.59, -7.53], '0.5': [-0.61, -0.41, -1.43], '0.9': [3.12, 4.56, 3.91]}]}\n"
     ]
    }
   ],
   "source": [
    "payload = {\n",
    "    \"inputs\": [\n",
    "        {\n",
    "            \"target\": [1.0, 2.0, 3.0, 2.0, 0.5, 2.0, 3.0, 2.0, 1.0],\n",
    "            # past_covariates must have the same length as \"target\"\n",
    "            \"past_covariates\": {\n",
    "                \"feat_1\": [3.0, 6.0, 9.0, 6.0, 1.5, 6.0, 9.0, 6.0, 3.0],\n",
    "                # Categorical covariates should be provided as strings\n",
    "                \"feat_2\": [\"A\", \"B\", \"B\", \"B\", \"A\", \"A\", \"A\", \"A\", \"B\"],\n",
    "                # feat_3 is a past-only covariate (not present in future_covariates)\n",
    "                \"feat_3\": [10.0, 20.0, 30.0, 20.0, 5.0, 20.0, 30.0, 20.0, 10.0],\n",
    "            },\n",
    "            # future_covariates must have length equal to \"prediction_length\"\n",
    "            \"future_covariates\": {\n",
    "                \"feat_1\": [2.5, 2.2, 3.3],\n",
    "                \"feat_2\": [\"B\", \"A\", \"A\"],\n",
    "            },\n",
    "        },\n",
    "        {\n",
    "            \"target\": [5.4, 3.0, 3.0, 2.0, 1.5, 2.0, -1.0],\n",
    "            \"past_covariates\": {\n",
    "                \"feat_1\": [0.6, 1.2, 1.8, 1.2, 0.3, 1.2, 1.8],\n",
    "                \"feat_2\": [\"A\", \"B\", \"B\", \"B\", \"A\", \"A\", \"A\"],\n",
    "                \"feat_3\": [5.4, 3.0, 3.0, 2.0, 1.5, 2.0, -1.0],\n",
    "            },\n",
    "            \"future_covariates\": {\n",
    "                \"feat_1\": [1.2, 0.3, 4.4],\n",
    "                \"feat_2\": [\"A\", \"B\", \"A\"],\n",
    "            },\n",
    "        },\n",
    "    ],\n",
    "    \"parameters\": {\n",
    "        \"prediction_length\": 3,\n",
    "        \"quantile_levels\": [0.1, 0.5, 0.9],\n",
    "    },\n",
    "}\n",
    "response = predictor.predict(payload)\n",
    "print(pretty_format(response))"
   ]
  },
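  {
   "cell_type": "markdown",
   "id": "9a4c6d50",
   "metadata": {},
   "source": [
    "Mismatched covariate lengths are an easy mistake when building payloads by hand. The hypothetical helper `describe_covariates` in the next cell is a minimal client-side sketch of the two length rules described above: past covariates must match the length of `target`, and future covariates must match `prediction_length`. It also groups covariate names into known-future and past-only covariates, using the first series of the payload above. It is a convenience check, not the endpoint's own validation logic. How keys that appear only in `future_covariates` are handled is not described in this notebook, so the helper simply reports them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a4c6d51",
   "metadata": {},
   "outputs": [],
   "source": [
    "def describe_covariates(series, prediction_length):\n",
    "    \"\"\"Check covariate lengths for one series and group covariate names (hypothetical helper).\"\"\"\n",
    "    target = series[\"target\"]\n",
    "    # Multivariate targets are lists of lists; use the length of the time axis\n",
    "    num_observations = len(target[0]) if isinstance(target[0], list) else len(target)\n",
    "    past = series.get(\"past_covariates\", {})\n",
    "    future = series.get(\"future_covariates\", {})\n",
    "    for name, values in past.items():\n",
    "        if len(values) != num_observations:\n",
    "            raise ValueError(f\"past covariate {name!r}: length {len(values)}, expected {num_observations}\")\n",
    "    for name, values in future.items():\n",
    "        if len(values) != prediction_length:\n",
    "            raise ValueError(f\"future covariate {name!r}: length {len(values)}, expected {prediction_length}\")\n",
    "    return {\n",
    "        # In both dicts: treated as known future covariates\n",
    "        \"known_future\": sorted(set(past) & set(future)),\n",
    "        # Only in past_covariates: treated as past-only covariates\n",
    "        \"past_only\": sorted(set(past) - set(future)),\n",
    "        # Only in future_covariates: handling is not described in this notebook\n",
    "        \"future_only\": sorted(set(future) - set(past)),\n",
    "    }\n",
    "\n",
    "\n",
    "# First series of the covariate payload above\n",
    "example_series = {\n",
    "    \"target\": [1.0, 2.0, 3.0, 2.0, 0.5, 2.0, 3.0, 2.0, 1.0],\n",
    "    \"past_covariates\": {\n",
    "        \"feat_1\": [3.0, 6.0, 9.0, 6.0, 1.5, 6.0, 9.0, 6.0, 3.0],\n",
    "        \"feat_2\": [\"A\", \"B\", \"B\", \"B\", \"A\", \"A\", \"A\", \"A\", \"B\"],\n",
    "        \"feat_3\": [10.0, 20.0, 30.0, 20.0, 5.0, 20.0, 30.0, 20.0, 10.0],\n",
    "    },\n",
    "    \"future_covariates\": {\n",
    "        \"feat_1\": [2.5, 2.2, 3.3],\n",
    "        \"feat_2\": [\"B\", \"A\", \"A\"],\n",
    "    },\n",
    "}\n",
    "print(describe_covariates(example_series, prediction_length=3))"
   ]
  },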
  {
   "cell_type": "markdown",
   "id": "76c88a22",
   "metadata": {},
   "source": [
    "### Multivariate forecasting\n",
    "\n",
    "Chronos-2 also supports **multivariate forecasting**, where multiple related time series are forecasted jointly. For multivariate forecasting, provide the target as a list of lists, where each inner list represents one dimension of the multivariate series."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b73609be",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'predictions': [{'mean': [[3.66, 3.55, 3.5, 3.42], [2.0, 2.05, 2.19, 2.23], [3.33, 3.27, 3.25, 3.22]],\n",
      "                  '0.1': [[1.98, 1.52, 1.17, 0.88], [0.84, 0.18, 0.0, -0.25], [2.5, 2.27, 2.08, 1.94]],\n",
      "                  '0.5': [[3.66, 3.55, 3.5, 3.42], [2.0, 2.05, 2.19, 2.23], [3.33, 3.27, 3.25, 3.22]],\n",
      "                  '0.9': [[5.75, 6.25, 6.59, 7.0], [3.8, 4.47, 4.88, 5.31], [4.38, 4.62, 4.78, 5.0]]}]}\n"
     ]
    }
   ],
   "source": [
    "payload = {\n",
    "    \"inputs\": [\n",
    "        {\n",
    "            # For multivariate forecasting, target is a list of lists\n",
    "            # Each inner list represents one dimension with the same length\n",
    "            # np.array(target) would have shape [num_dimensions, length]\n",
    "            \"target\": [\n",
    "                [1.0, 2.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0],  # Dimension 1\n",
    "                [5.0, 4.0, 3.0, 4.0, 5.0, 4.0, 3.0, 2.0],  # Dimension 2\n",
    "                [2.0, 2.5, 3.0, 2.5, 2.0, 2.5, 3.0, 3.5],  # Dimension 3\n",
    "            ],\n",
    "        },\n",
    "    ],\n",
    "    \"parameters\": {\n",
    "        \"prediction_length\": 4,\n",
    "        \"quantile_levels\": [0.1, 0.5, 0.9],\n",
    "    },\n",
    "}\n",
    "response = predictor.predict(payload)\n",
    "print(pretty_format(response))"
   ]
  },
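  {
   "cell_type": "markdown",
   "id": "b6e8f170",
   "metadata": {},
   "source": [
    "The response for a multivariate series uses the same layout as the input. Each returned statistic is a list of lists, with one inner list per dimension. The next cell is a minimal NumPy sketch. It takes the target from the payload above and the rounded `mean` forecast printed by the previous cell, and confirms the `[num_dimensions, length]` and `[num_dimensions, prediction_length]` shapes. It then appends the forecast to the history along the time axis. If your data is stored time-major (one row per timestamp, one column per dimension), transpose it before building the payload. The variable names are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6e8f171",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Target from the multivariate payload above: 3 dimensions x 8 observations\n",
    "mv_target = np.array([\n",
    "    [1.0, 2.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0],\n",
    "    [5.0, 4.0, 3.0, 4.0, 5.0, 4.0, 3.0, 2.0],\n",
    "    [2.0, 2.5, 3.0, 2.5, 2.0, 2.5, 3.0, 3.5],\n",
    "])\n",
    "# Rounded mean forecast copied from the output above: 3 dimensions x 4 steps\n",
    "mv_mean = np.array([[3.66, 3.55, 3.5, 3.42], [2.0, 2.05, 2.19, 2.23], [3.33, 3.27, 3.25, 3.22]])\n",
    "print(\"target shape:\", mv_target.shape)  # (num_dimensions, length)\n",
    "print(\"forecast shape:\", mv_mean.shape)  # (num_dimensions, prediction_length)\n",
    "\n",
    "# Time is the last axis, so history and forecast are joined along axis=1\n",
    "mv_history_and_forecast = np.concatenate([mv_target, mv_mean], axis=1)\n",
    "print(\"history + forecast shape:\", mv_history_and_forecast.shape)\n",
    "\n",
    "# Time-major data (rows = timestamps, columns = dimensions) is transposed before sending\n",
    "time_major = mv_target.T  # shape (8, 3)\n",
    "mv_payload_target = time_major.T.tolist()  # list of lists, one per dimension\n",
    "print(len(mv_payload_target), len(mv_payload_target[0]))"
   ]
  },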
  {
   "cell_type": "markdown",
   "id": "9479e7a3",
   "metadata": {},
   "source": [
    "## Endpoint API\n",
    "So far, we have explored several examples of querying the endpoint with different payload structures. Below is a comprehensive API specification detailing all supported parameters, their meanings, and how they affect the model’s predictions.\n",
    "\n",
    "* **inputs** (required): List of at most 1000 time series to forecast. Each time series is represented by a dictionary with the following keys:\n",
    "    * **target** (required): Observed time series values.\n",
    "        - For univariate forecasting: List of numeric values.\n",
    "        - For multivariate forecasting: List of lists, where each inner list represents one dimension. All dimensions must have the same length. If converted to a numpy array via `np.array(target)`, the shape would be `[num_dimensions, length]`.\n",
    "        - It is recommended that each time series contains at least 30 observations.\n",
    "        - If any time series contains fewer than 5 observations, an error will be raised.\n",
    "    * **item_id**: String that uniquely identifies each time series.\n",
    "        - If provided, the ID must be unique for each time series.\n",
    "        - If provided, then the endpoint response will also include the **item_id** field for each forecast.\n",
    "    * **start**: Timestamp of the first time series observation in ISO format (`YYYY-MM-DD` or `YYYY-MM-DDThh:mm:ss`).\n",
    "        - If **start** field is provided, then **freq** must also be provided as part of **parameters**.\n",
    "        - If provided, then the endpoint response will also include the **start** field indicating the first timestamp of each forecast.\n",
    "    * **past_covariates**: Dictionary containing the past values of the covariates for this time series.\n",
    "        - Each key in **past_covariates** corresponds to the name of a covariate. Each value must be an array consisting of all-numeric or all-string values, with length equal to the length of **target**.\n",
    "        - Covariates that appear only in **past_covariates** (and not in **future_covariates**) are treated as past-only covariates.\n",
    "    * **future_covariates**: Dictionary containing the future values of the covariates for this time series (values during the forecast horizon).\n",
    "        - Each key in **future_covariates** corresponds to the name of a covariate. Each value must be an array consisting of all-numeric or all-string values, with length equal to **prediction_length**.\n",
    "        - Covariates that appear in both **past_covariates** and **future_covariates** are treated as known future covariates.\n",
    "* **parameters**: Optional parameters to configure the model.\n",
    "    * **prediction_length**: Integer corresponding to the number of future time series values that need to be predicted. Defaults to `1`. Values up to `1024` are supported.\n",
    "    * **quantile_levels**: List of floats in range (0, 1) specifying which quantiles should be included in the probabilistic forecast. Defaults to `[0.1, 0.5, 0.9]`.\n",
    "        - Chronos-2 natively supports quantile levels in range `[0.01, 0.99]`. Predictions for quantile levels outside this range are clipped.\n",
    "    * **freq**: Frequency of the time series observations in [pandas-compatible format](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases). For example, `1h` for hourly data or `2W` for bi-weekly data.\n",
    "        - If **freq** is provided, then **start** must also be provided for each time series in **inputs**.\n",
    "    * **batch_size**: Number of time series processed in parallel by the model. Larger values speed up inference but may lead to out-of-memory errors. Defaults to `256`.\n",
    "    * **predict_batches_jointly**: If `True`, the model will apply group attention to all items in the batch, instead of processing each item separately (described as \"full cross-learning mode\" in the [technical report](https://www.arxiv.org/abs/2510.15821)). This may produce more accurate forecasts at the cost of lower inference speed. Defaults to `False`.\n",
    "\n",
    "All keys not marked with (required) are optional.\n",
    "\n",
    "The endpoint response contains the probabilistic (quantile) forecast for each time series included in the request."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ee1e161",
   "metadata": {},
   "source": [
    "## Working with long-format data frames"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a744884c",
   "metadata": {},
   "source": [
    "The endpoint communicates using JSON format for both input and output. However, in practice, time series data is often stored in a **long-format data frame** (where each row represents a timestamp for a specific item).\n",
    "\n",
    "In the following example, we demonstrate how to:\n",
    "\n",
    "1. Convert a long-format data frame into the JSON payload format required by the endpoint.\n",
    "2. Send the request and retrieve predictions.\n",
    "3. Convert the response back into a long-format data frame for further analysis.\n",
    "\n",
    "First, we load an example dataset in long data frame format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6ecba0ca",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>item_id</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>scaled_price</th>\n",
       "      <th>promotion_email</th>\n",
       "      <th>promotion_homepage</th>\n",
       "      <th>unit_sales</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-01-01</td>\n",
       "      <td>0.879130</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>636.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-01-08</td>\n",
       "      <td>0.994517</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>123.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-01-15</td>\n",
       "      <td>1.005513</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>391.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-01-22</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>339.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-01-29</td>\n",
       "      <td>0.883309</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>661.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    item_id  timestamp  scaled_price  promotion_email  promotion_homepage  \\\n",
       "0  1062_101 2018-01-01      0.879130              0.0                 0.0   \n",
       "1  1062_101 2018-01-08      0.994517              0.0                 0.0   \n",
       "2  1062_101 2018-01-15      1.005513              0.0                 0.0   \n",
       "3  1062_101 2018-01-22      1.000000              0.0                 0.0   \n",
       "4  1062_101 2018-01-29      0.883309              0.0                 0.0   \n",
       "\n",
       "   unit_sales  \n",
       "0       636.0  \n",
       "1       123.0  \n",
       "2       391.0  \n",
       "3       339.0  \n",
       "4       661.0  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "df = pd.read_csv(\n",
    "    \"https://autogluon.s3.amazonaws.com/datasets/timeseries/grocery_sales/test.csv\",\n",
    "    parse_dates=[\"timestamp\"],\n",
    ")\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4288470c",
   "metadata": {},
   "source": [
    "We split the data into two parts:\n",
    "- Past data, including historic values of the target column and the covariates.\n",
    "- Future data that contains the future values of the covariates during the forecast horizon."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d95eb6d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "prediction_length = 8\n",
    "target_col = \"unit_sales\"\n",
    "freq = pd.infer_freq(df[df.item_id == df.item_id[0]][\"timestamp\"])\n",
    "\n",
    "past_df = df.groupby(\"item_id\").head(-prediction_length)\n",
    "future_df = df.groupby(\"item_id\").tail(prediction_length).drop(columns=[target_col])"
   ]
  },
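  {
   "cell_type": "markdown",
   "id": "e1b5c390",
   "metadata": {},
   "source": [
    "The split above relies on `groupby(\"item_id\").head(-prediction_length)`, which keeps all but the last `prediction_length` rows of each item. It also uses `tail(prediction_length)`, which keeps the last `prediction_length` rows. Both select rows by position, so they assume each item's rows are sorted by timestamp. The next cell is a small sketch on hypothetical toy data. It shows this behavior, along with the sort you could apply first if your data is not already ordered."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1b5c391",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Hypothetical toy data: two items, with item a's rows out of time order\n",
    "toy = pd.DataFrame(\n",
    "    {\n",
    "        \"item_id\": [\"a\", \"a\", \"a\", \"a\", \"b\", \"b\", \"b\"],\n",
    "        \"timestamp\": pd.to_datetime(\n",
    "            [\"2024-01-04\", \"2024-01-01\", \"2024-01-02\", \"2024-01-03\", \"2024-01-01\", \"2024-01-02\", \"2024-01-03\"]\n",
    "        ),\n",
    "        \"value\": [4, 1, 2, 3, 10, 20, 30],\n",
    "    }\n",
    ")\n",
    "# head/tail select rows by position, so sort by item and time first\n",
    "toy = toy.sort_values([\"item_id\", \"timestamp\"])\n",
    "\n",
    "n_future = 2\n",
    "toy_past = toy.groupby(\"item_id\").head(-n_future)  # all but the last 2 rows of each item\n",
    "toy_future = toy.groupby(\"item_id\").tail(n_future)  # the last 2 rows of each item\n",
    "print(toy_past)\n",
    "print(toy_future)"
   ]
  },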
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c2482ffd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>item_id</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>scaled_price</th>\n",
       "      <th>promotion_email</th>\n",
       "      <th>promotion_homepage</th>\n",
       "      <th>unit_sales</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-01-01</td>\n",
       "      <td>0.879130</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>636.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-01-08</td>\n",
       "      <td>0.994517</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>123.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-01-15</td>\n",
       "      <td>1.005513</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>391.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-01-22</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>339.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-01-29</td>\n",
       "      <td>0.883309</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>661.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    item_id  timestamp  scaled_price  promotion_email  promotion_homepage  \\\n",
       "0  1062_101 2018-01-01      0.879130              0.0                 0.0   \n",
       "1  1062_101 2018-01-08      0.994517              0.0                 0.0   \n",
       "2  1062_101 2018-01-15      1.005513              0.0                 0.0   \n",
       "3  1062_101 2018-01-22      1.000000              0.0                 0.0   \n",
       "4  1062_101 2018-01-29      0.883309              0.0                 0.0   \n",
       "\n",
       "   unit_sales  \n",
       "0       636.0  \n",
       "1       123.0  \n",
       "2       391.0  \n",
       "3       339.0  \n",
       "4       661.0  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "past_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "c47905f8",
   "metadata": {
    "lines_to_next_cell": 1
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>item_id</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>scaled_price</th>\n",
       "      <th>promotion_email</th>\n",
       "      <th>promotion_homepage</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-06-11</td>\n",
       "      <td>1.005425</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-06-18</td>\n",
       "      <td>1.005454</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-06-25</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-07-02</td>\n",
       "      <td>1.005513</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-07-09</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     item_id  timestamp  scaled_price  promotion_email  promotion_homepage\n",
       "23  1062_101 2018-06-11      1.005425              0.0                 0.0\n",
       "24  1062_101 2018-06-18      1.005454              0.0                 0.0\n",
       "25  1062_101 2018-06-25      1.000000              0.0                 0.0\n",
       "26  1062_101 2018-07-02      1.005513              0.0                 0.0\n",
       "27  1062_101 2018-07-09      1.000000              0.0                 0.0"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "future_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9f54de7",
   "metadata": {},
   "source": [
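    "Next, we convert the past and future data frames into the JSON payload format expected by the endpoint. Covariates that appear only in `past_df` are sent as past-only covariates. Covariates that are also present in `future_df` are sent together with their known future values."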
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "55bbad68",
   "metadata": {
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "def convert_df_to_payload(\n",
    "    past_df,\n",
    "    future_df=None,\n",
    "    prediction_length=1,\n",
    "    freq=\"D\",\n",
    "    target=\"target\",\n",
    "    id_column=\"item_id\",\n",
    "    timestamp_column=\"timestamp\",\n",
    "):\n",
    "    \"\"\"\n",
    "    Converts past and future DataFrames into JSON payload format for the Chronos endpoint.\n",
    "\n",
    "    Args:\n",
    "        past_df (pd.DataFrame): Historical data with `target`, `timestamp_column`, and `id_column`.\n",
    "        future_df (pd.DataFrame, optional): Future covariates with `timestamp_column` and `id_column`.\n",
    "            Covariates in past_df but not in future_df are treated as past-only covariates.\n",
    "        prediction_length (int): Number of future time steps to predict.\n",
    "        freq (str): Pandas-compatible frequency of the time series.\n",
    "        target (str or list[str]): Column name(s) for target values.\n",
    "            Use a string for univariate forecasting or a list of strings for multivariate forecasting.\n",
    "        id_column (str): Column name for item IDs.\n",
    "        timestamp_column (str): Column name for timestamps.\n",
    "\n",
    "    Returns:\n",
    "        dict: JSON payload formatted for the Chronos endpoint.\n",
    "    \"\"\"\n",
    "    past_df = past_df.sort_values([id_column, timestamp_column])\n",
    "    if future_df is not None:\n",
    "        future_df = future_df.sort_values([id_column, timestamp_column])\n",
    "\n",
    "    target_cols = [target] if isinstance(target, str) else target\n",
    "    past_covariate_cols = list(past_df.columns.drop([*target_cols, id_column, timestamp_column]))\n",
    "    future_covariate_cols = [] if future_df is None else [col for col in past_covariate_cols if col in future_df.columns]\n",
    "\n",
    "    inputs = []\n",
    "    for item_id, past_group in past_df.groupby(id_column):\n",
    "        if len(target_cols) > 1:\n",
    "            target_values = [past_group[col].tolist() for col in target_cols]\n",
    "            series_length = len(target_values[0])\n",
    "        else:\n",
    "            target_values = past_group[target_cols[0]].tolist()\n",
    "            series_length = len(target_values)\n",
    "\n",
    "        if series_length < 5:\n",
    "            raise ValueError(f\"Time series '{item_id}' has fewer than 5 observations.\")\n",
    "\n",
    "        series_dict = {\n",
    "            \"target\": target_values,\n",
    "            \"item_id\": str(item_id),\n",
    "            \"start\": past_group[timestamp_column].iloc[0].isoformat(),\n",
    "        }\n",
    "\n",
    "        if past_covariate_cols:\n",
    "            series_dict[\"past_covariates\"] = past_group[past_covariate_cols].to_dict(orient=\"list\")\n",
    "\n",
    "        if future_covariate_cols:\n",
    "            future_group = future_df[future_df[id_column] == item_id]\n",
    "            if len(future_group) != prediction_length:\n",
    "                raise ValueError(\n",
    "                    f\"future_df must contain exactly {prediction_length=} values for each item_id from past_df \"\n",
    "                    f\"(got {len(future_group)=}) for {item_id=}\"\n",
    "                )\n",
    "            series_dict[\"future_covariates\"] = future_group[future_covariate_cols].to_dict(orient=\"list\")\n",
    "\n",
    "        inputs.append(series_dict)\n",
    "\n",
    "    return {\n",
    "        \"inputs\": inputs,\n",
    "        \"parameters\": {\"prediction_length\": prediction_length, \"freq\": freq},\n",
    "    }"
   ]
  },
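  {
   "cell_type": "markdown",
   "id": "7c1e4a90",
   "metadata": {},
   "source": [
    "For reference, the cell below is a minimal sketch of a hand-written payload for one toy series. It mirrors the structure that `convert_df_to_payload` produces: `target` and `start` describe the history, every covariate appears under `past_covariates`, and covariates with known future values are repeated under `future_covariates`. The item ID, values, and frequency are hypothetical. The checks restate the length constraints implied by the function above: at least 5 observations, and exactly `prediction_length` future covariate values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c1e4a91",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Hypothetical toy series: 6 daily observations of the target and one covariate\n",
    "toy_payload = {\n",
    "    \"inputs\": [\n",
    "        {\n",
    "            \"target\": [10.0, 12.0, 11.0, 13.0, 12.5, 14.0],\n",
    "            \"item_id\": \"toy_item\",\n",
    "            \"start\": \"2018-01-01T00:00:00\",\n",
    "            # Covariate values over the history (same length as target)\n",
    "            \"past_covariates\": {\"promo\": [0.0, 1.0, 0.0, 0.0, 1.0, 0.0]},\n",
    "            # Known future values of the same covariate (length == prediction_length)\n",
    "            \"future_covariates\": {\"promo\": [1.0, 0.0]},\n",
    "        }\n",
    "    ],\n",
    "    \"parameters\": {\"prediction_length\": 2, \"freq\": \"D\"},\n",
    "}\n",
    "\n",
    "# Length constraints implied by convert_df_to_payload\n",
    "series = toy_payload[\"inputs\"][0]\n",
    "horizon = toy_payload[\"parameters\"][\"prediction_length\"]\n",
    "assert len(series[\"target\"]) >= 5\n",
    "assert all(len(v) == len(series[\"target\"]) for v in series[\"past_covariates\"].values())\n",
    "assert all(len(v) == horizon for v in series[\"future_covariates\"].values())\n",
    "\n",
    "# The payload must be JSON-serializable before it is sent to the endpoint\n",
    "payload_json = json.dumps(toy_payload)\n",
    "print(f\"Payload size: {len(payload_json.encode('utf-8'))} bytes\")"
   ]
  },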
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d5226957",
   "metadata": {},
   "outputs": [],
   "source": [
    "payload = convert_df_to_payload(\n",
    "    past_df,\n",
    "    future_df,\n",
    "    prediction_length=prediction_length,\n",
    "    freq=freq,\n",
    "    target=\"unit_sales\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4611c3e6",
   "metadata": {},
   "source": [
    "We can now send the payload to the endpoint to generate the forecasts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "504a731e",
   "metadata": {
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "response = predictor.predict(payload)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "742be985",
   "metadata": {},
   "source": [
    "Note that Chronos-2 generated forecasts for more than 300 time series in the dataset (including covariates) in under 2 seconds.\n",
    "\n",
    "Finally, we convert the response back into a long-format data frame."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "a48807f6",
   "metadata": {
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "def convert_response_to_df(response, freq=\"D\"):\n",
    "    \"\"\"\n",
    "    Converts a JSON response from the Chronos endpoint into a long-format DataFrame.\n",
    "\n",
    "    Args:\n",
    "        response (dict): JSON response containing forecasts.\n",
    "        freq (str): Pandas-compatible frequency of the time series.\n",
    "\n",
    "    Returns:\n",
    "        pd.DataFrame: Long-format DataFrame with timestamps, item_id, and forecasted values.\n",
    "            For multivariate forecasts, creates separate rows for each target dimension (target_1, target_2, etc.).\n",
    "    \"\"\"\n",
    "    dfs = []\n",
    "    for forecast in response[\"predictions\"]:\n",
    "        if isinstance(forecast[\"mean\"], list) and isinstance(forecast[\"mean\"][0], list):\n",
    "            # Multivariate forecast\n",
    "            timestamps = pd.date_range(forecast[\"start\"], freq=freq, periods=len(forecast[\"mean\"][0]))\n",
    "            for dim_idx in range(len(forecast[\"mean\"])):\n",
    "                dim_data = {\"item_id\": forecast.get(\"item_id\"), \"timestamp\": timestamps, \"target\": f\"target_{dim_idx + 1}\"}\n",
    "                for key, value in forecast.items():\n",
    "                    if key not in [\"item_id\", \"start\"]:\n",
    "                        dim_data[key] = value[dim_idx]\n",
    "                dfs.append(pd.DataFrame(dim_data))\n",
    "        else:\n",
    "            # Univariate forecast\n",
    "            forecast_df = pd.DataFrame(forecast).drop(columns=[\"start\"])\n",
    "            forecast_df[\"timestamp\"] = pd.date_range(forecast[\"start\"], freq=freq, periods=len(forecast_df))\n",
    "            # Reorder columns to have item_id and timestamp first\n",
    "            cols = [\"item_id\", \"timestamp\"] + [c for c in forecast_df.columns if c not in [\"item_id\", \"timestamp\"]]\n",
    "            forecast_df = forecast_df[cols]\n",
    "            dfs.append(forecast_df)\n",
    "\n",
    "    return pd.concat(dfs, ignore_index=True)"
   ]
  },
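  {
   "cell_type": "markdown",
   "id": "5b8d2f30",
   "metadata": {},
   "source": [
    "The multivariate branch of `convert_response_to_df` (used when `target` is a list of columns) is not exercised in this notebook. The cell below is a minimal sketch that assumes, based on how that function indexes the response, that a multivariate forecast stores one list per target dimension under `mean` and under each quantile key. The item ID and all values are hypothetical. Reshaping the toy response the same way yields one block of rows per target dimension, labelled `target_1`, `target_2`, and so on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b8d2f31",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Hypothetical multivariate response: one series, 2 target dimensions, prediction_length=3.\n",
    "# Each forecast field holds one list per target dimension (assumed shape).\n",
    "toy_response = {\n",
    "    \"predictions\": [\n",
    "        {\n",
    "            \"item_id\": \"toy_item\",\n",
    "            \"start\": \"2018-01-07T00:00:00\",\n",
    "            \"mean\": [[14.1, 14.3, 14.0], [3.2, 3.1, 3.3]],\n",
    "            \"0.5\": [[14.0, 14.2, 14.0], [3.2, 3.0, 3.3]],\n",
    "        }\n",
    "    ]\n",
    "}\n",
    "\n",
    "blocks = []\n",
    "for forecast in toy_response[\"predictions\"]:\n",
    "    n_steps = len(forecast[\"mean\"][0])\n",
    "    timestamps = pd.date_range(forecast[\"start\"], freq=\"D\", periods=n_steps)\n",
    "    for dim_idx in range(len(forecast[\"mean\"])):\n",
    "        # One block of rows per target dimension, labelled target_1, target_2, ...\n",
    "        block = {\"item_id\": forecast[\"item_id\"], \"timestamp\": timestamps, \"target\": f\"target_{dim_idx + 1}\"}\n",
    "        for key, value in forecast.items():\n",
    "            if key not in (\"item_id\", \"start\"):\n",
    "                block[key] = value[dim_idx]\n",
    "        blocks.append(pd.DataFrame(block))\n",
    "\n",
    "toy_df = pd.concat(blocks, ignore_index=True)\n",
    "print(toy_df)"
   ]
  },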
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ce0cf954",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>item_id</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>mean</th>\n",
       "      <th>0.1</th>\n",
       "      <th>0.5</th>\n",
       "      <th>0.9</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-06-11</td>\n",
       "      <td>320.0</td>\n",
       "      <td>186.0</td>\n",
       "      <td>320.0</td>\n",
       "      <td>488.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-06-18</td>\n",
       "      <td>318.0</td>\n",
       "      <td>175.0</td>\n",
       "      <td>318.0</td>\n",
       "      <td>496.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-06-25</td>\n",
       "      <td>316.0</td>\n",
       "      <td>169.0</td>\n",
       "      <td>316.0</td>\n",
       "      <td>508.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-07-02</td>\n",
       "      <td>316.0</td>\n",
       "      <td>171.0</td>\n",
       "      <td>316.0</td>\n",
       "      <td>506.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-07-09</td>\n",
       "      <td>310.0</td>\n",
       "      <td>165.0</td>\n",
       "      <td>310.0</td>\n",
       "      <td>506.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    item_id  timestamp   mean    0.1    0.5    0.9\n",
       "0  1062_101 2018-06-11  320.0  186.0  320.0  488.0\n",
       "1  1062_101 2018-06-18  318.0  175.0  318.0  496.0\n",
       "2  1062_101 2018-06-25  316.0  169.0  316.0  508.0\n",
       "3  1062_101 2018-07-02  316.0  171.0  316.0  506.0\n",
       "4  1062_101 2018-07-09  310.0  165.0  310.0  506.0"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "forecast_df = convert_response_to_df(response, freq=freq)\n",
    "forecast_df.head()"
   ]
  },
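  {
   "cell_type": "markdown",
   "id": "9e3a6c10",
   "metadata": {},
   "source": [
    "The single `predictor.predict(payload)` call above works because this payload is small. SageMaker real-time endpoints cap the size of each request body (6 MB per invocation according to the SageMaker service quotas; verify the current limit for your account). For much larger datasets, one option is to split the `inputs` list into several requests that share the same `parameters`. The cell below is a minimal sketch under that assumption: `split_payload` and `payload_size_mb` are hypothetical helpers, the batch size of 100 series is arbitrary, and the toy payload is made up. The commented-out lines show how the batches could be sent to the live endpoint and the forecasts merged."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e3a6c11",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "\n",
    "def split_payload(payload, max_series_per_request=100):\n",
    "    \"\"\"Hypothetical helper: split a payload into smaller payloads sharing the same parameters.\"\"\"\n",
    "    inputs = payload[\"inputs\"]\n",
    "    return [\n",
    "        {\"inputs\": inputs[i : i + max_series_per_request], \"parameters\": payload[\"parameters\"]}\n",
    "        for i in range(0, len(inputs), max_series_per_request)\n",
    "    ]\n",
    "\n",
    "\n",
    "def payload_size_mb(payload):\n",
    "    \"\"\"Hypothetical helper: size of the JSON request body in megabytes.\"\"\"\n",
    "    return len(json.dumps(payload).encode(\"utf-8\")) / 1e6\n",
    "\n",
    "\n",
    "# Toy payload with 250 short series (values are made up)\n",
    "toy_payload = {\n",
    "    \"inputs\": [\n",
    "        {\"target\": [1.0, 2.0, 3.0, 4.0, 5.0], \"item_id\": f\"item_{i}\", \"start\": \"2018-01-01T00:00:00\"}\n",
    "        for i in range(250)\n",
    "    ],\n",
    "    \"parameters\": {\"prediction_length\": 2, \"freq\": \"D\"},\n",
    "}\n",
    "\n",
    "batches = split_payload(toy_payload, max_series_per_request=100)\n",
    "for batch in batches:\n",
    "    print(len(batch[\"inputs\"]), \"series,\", f\"{payload_size_mb(batch):.4f} MB\")\n",
    "\n",
    "# With a live endpoint, each batch could be sent separately and the forecasts merged:\n",
    "# responses = [predictor.predict(batch) for batch in batches]\n",
    "# forecast_df = pd.concat([convert_response_to_df(r, freq=freq) for r in responses], ignore_index=True)"
   ]
  },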
  {
   "cell_type": "markdown",
   "id": "e89cbc36",
   "metadata": {},
   "source": [
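    "## Clean up the endpoint\n",
    "When you are finished, delete the endpoint to avoid unnecessary charges. A deployed endpoint keeps incurring costs for as long as it is running, even when it receives no requests."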
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "9602a0a5",
   "metadata": {
    "lines_to_next_cell": 3
   },
   "outputs": [],
   "source": [
    "predictor.delete_predictor()"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  },
  "kernelspec": {
   "display_name": "sm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
